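"""
Empirically validate the L-smoothness bound of a variational quantum algorithm (VQA):
compare the largest sampled Hessian spectral norm (Lmax) with the theoretical upper
bound (Lupper = P * ||M||, where P is the number of parameters and M the observable)
as the number of parameters grows.
You can grow the VQA by adding more layers or qubits.
"""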

import itertools
import pennylane as qml
import pennylane.numpy as pnp
import matplotlib.pyplot as plt
import seaborn as sns

def create_qnn(n_layers, n_qubits, n_gates, observable_coeffs, observable_ops, entangled=True):
    """Build a layered variational circuit measuring <M>, M = sum_i c_i * O_i.

    Each layer applies, on every qubit, the first ``n_gates`` rotations of the
    sequence RX, RZ, RY. If ``entangled`` is set, it then applies a ring of CNOTs
    (qubit i controls qubit (i + 1) mod n_qubits). The ring is skipped for a
    single qubit.

    Args:
        n_layers: Number of repeated layers.
        n_qubits: Number of wires on the ``default.qubit`` device.
        n_gates: Rotations per qubit per layer; must be 1, 2 or 3.
        observable_coeffs: Coefficients c_i of the measured Hamiltonian.
        observable_ops: Operators O_i of the measured Hamiltonian.
        entangled: Whether to append the CNOT ring after each rotation layer.

    Returns:
        A QNode ``circuit(params)`` that reads ``params[layer][qubit][gate]``
        (i.e. shape ``(n_layers, n_qubits, n_gates)``) and returns the expectation
        value of the Hamiltonian.

    Raises:
        ValueError: If ``n_gates`` is not 1, 2 or 3. Previously such values
            silently produced a circuit without the requested rotations.
    """
    rotations = (qml.RX, qml.RZ, qml.RY)
    if n_gates not in (1, 2, 3):
        raise ValueError(f"n_gates must be 1, 2 or 3, got {n_gates}")
    gates = rotations[:n_gates]

    dev = qml.device('default.qubit', wires=n_qubits)
    # The observable does not depend on the circuit parameters, so build it once
    # here instead of on every circuit evaluation.
    observable = qml.Hamiltonian(observable_coeffs, observable_ops)
    # A CNOT needs two distinct wires, so the ring only exists for >= 2 qubits.
    use_ring = entangled and n_qubits > 1

    @qml.qnode(dev)
    def circuit(params):
        for layer in range(n_layers):
            for qubit in range(n_qubits):
                for gate_idx, gate in enumerate(gates):
                    gate(params[layer][qubit][gate_idx], wires=qubit)

            if use_ring:
                for qubit in range(n_qubits):
                    qml.CNOT(wires=[qubit, (qubit + 1) % n_qubits])

        return qml.expval(observable)

    return circuit

def generate_parameter_samples(n_layers, n_qubits, n_samples, n_gates=2, seed=42):
    """Draw ``n_samples`` parameter tensors uniformly from [0, 2*pi).

    Args:
        n_layers: First dimension of a single sample.
        n_qubits: Second dimension of a single sample.
        n_samples: Number of samples to draw (0 yields an empty, correctly shaped array).
        n_gates: Third dimension of a single sample (rotations per qubit per layer).
        seed: Value passed to ``pnp.random.seed`` before drawing, so repeated calls
            are reproducible. The default of 42 matches the previous behaviour.
            Pass ``None`` to leave the global random state as it is.

    Returns:
        Array of shape ``(n_samples, n_layers, n_qubits, n_gates)``.
    """
    if seed is not None:
        pnp.random.seed(seed)
    # A single vectorised draw fills the array in C order from the same global
    # stream, giving the same values as drawing the samples one by one. It also
    # keeps the 4-D shape when n_samples == 0, where pnp.array([]) would collapse it.
    samples = pnp.random.uniform(0, 2 * pnp.pi, size=(n_samples, n_layers, n_qubits, n_gates))
    return pnp.array(samples)

def calculate_hessian_norms(qnn, samples):
    """Return the spectral norm of the Hessian of ``qnn`` at each parameter sample.

    The Hessian is taken with respect to the flattened parameter vector, so for a
    sample with ``P = params.size`` entries it is a ``P x P`` matrix. Its ``ord=2``
    matrix norm (largest singular value) is recorded for every sample.

    Args:
        qnn: Differentiable cost function taking one parameter array (e.g. a QNode).
        samples: Iterable of parameter arrays exposing ``.shape`` and ``.flatten()``.

    Returns:
        List of spectral norms, one per sample, in the order of ``samples``.
    """
    hessian_norms = []
    for params in samples:
        shape = params.shape
        # Mark the flat vector as trainable explicitly, so qml.grad still
        # differentiates it when the samples are not trainable PennyLane tensors.
        flat_params = pnp.array(params.flatten(), requires_grad=True)

        def cost_fn_flat(p_flat):
            # Called immediately below, so capturing the loop's `shape` is safe.
            return qnn(p_flat.reshape(shape))

        hessian_matrix = qml.jacobian(qml.grad(cost_fn_flat))(flat_params)
        hessian_norms.append(pnp.linalg.norm(hessian_matrix, ord=2))

    return hessian_norms

# ==================================================================
if __name__ == '__main__':
    results_data = []
    n_samples = 1000
    n_gates = 3
    n_qubits = 1
    n_layer_combos = [1, 5, 10, 15, 20, 25, 30, 35, 40]
    observable_coeffs = [1 / n_qubits] * n_qubits
    observable_ops = [qml.PauliZ(i) for i in range(n_qubits)]
    # ||M|| <= sum_i |c_i| * ||O_i||, and each Pauli term has operator norm 1.
    # Deriving the bound from the coefficients keeps it valid if they are
    # changed; it equals 1.0 for the normalised PauliZ sum above.
    norm_M = sum(abs(c) for c in observable_coeffs)

    for n_layers in n_layer_combos:
        qnn = create_qnn(n_layers, n_qubits, n_gates, observable_coeffs, observable_ops)
        samples = generate_parameter_samples(n_layers, n_qubits, n_samples, n_gates=n_gates)

        P = n_layers * n_qubits * n_gates
        L_bound = P * norm_M

        hessian_norms = calculate_hessian_norms(qnn, samples)
        max_norm = pnp.max(hessian_norms)
        # The purpose of the experiment: does every sampled Hessian respect L <= P * ||M||?
        all_within_bound = all(norm <= L_bound for norm in hessian_norms)

        print("--- Experiment Setup ---")
        print(f"Number of Layers: {n_layers}, Number of Qubits: {n_qubits}, Number of Gates: {n_gates}, Total Parameters (P): {P}")
        print(f"Theoretical L-Smoothness Bound (L <= P): {L_bound:.4f}")
        print(f"Largest Hessian Norm: {max_norm}")
        print(f"All Hessian Norms Within Bound: {all_within_bound}")
        results_data.append((n_layers, n_qubits, n_gates, max_norm))

    print(results_data)
